import pickle


import numpy as np

from ModularUtils.ControllerConstants import generate_permutations


def get_joint_distributions_from_samples(dim_list, corrensponding_samples):
    """Build an empirical joint distribution from rows of discrete samples.

    Every configuration produced by ``generate_permutations(dim_list)`` is
    first assigned a floor value of ``1e-6``. Each distinct row that occurs
    in the samples then has its entry set to its relative frequency (row
    count divided by the number of sample rows). The floor values are not
    re-normalised, so the returned values need not sum to exactly 1.

    Args:
        dim_list: Forwarded unchanged to ``generate_permutations``.
            NOTE(review): presumably one cardinality per variable -- confirm
            against ``generate_permutations``.
        corrensponding_samples: Array with one sample per row; it must expose
            ``.shape`` and be accepted by ``np.unique(..., axis=0)``.

    Returns:
        dict mapping a configuration tuple to its estimated probability.
    """
    all_configs = generate_permutations(dim_list)
    seen_rows, seen_counts = np.unique(
        corrensponding_samples, axis=0, return_counts=True
    )

    # Start every enumerated configuration at a small non-zero floor.
    joint = {tuple(config): 1e-6 for config in all_configs}

    # Overwrite the floor with the observed relative frequency.
    n_samples = corrensponding_samples.shape[0]
    for row, n_hits in zip(seen_rows, seen_counts):
        joint[tuple(row)] = n_hits / n_samples

    return joint


def get_obs_samples(obs_vars):
    """Dump the rows with ``INT == 0`` for the given variables to a pickle.

    Loads ``sachs.interventional.txt``, keeps the rows whose ``INT`` column
    is 0, selects the columns named in ``obs_vars`` (in that order),
    subtracts 1 from every value and pickles the resulting array to
    ``sachs_P(<vars joined by '_'>).txt`` in the dataset directory. Despite
    the ``.txt`` suffix, the output file is a binary pickle.

    NOTE(review): ``INT == 0`` presumably marks observational (non-intervened)
    samples, and the ``- 1`` presumably shifts 1-based levels to 0-based --
    confirm against the dataset description.

    Args:
        obs_vars: Sequence of column labels to keep.

    Returns:
        None.

    Raises:
        ValueError: If a label is unknown, or if no row has ``INT == 0``
            (raised by ``np.concatenate`` on an empty list).
    """
    sachs_labels = ["Raf", "Mek", "Plcg", "PIP2", "PIP3", "Erk", "Akt", "PKA", "PKC", "P38", "Jnk", "INT"]
    file_root = "/path_to_project/CausalSachs/GroundTruth/Dataset/"
    source_path = file_root + "sachs.interventional.txt"
    table = np.genfromtxt(source_path, delimiter=" ").astype(int)

    int_col = sachs_labels.index("INT")
    matching_rows = [row.reshape(1, -1) for row in table if row[int_col] == 0]
    stacked = np.concatenate(matching_rows, axis=0)

    columns = [sachs_labels.index(name) for name in obs_vars]
    shifted = stacked[:, columns] - 1

    joined_vars = "_".join(name for name in obs_vars)
    with open(file_root + f"sachs_P({joined_vars}).txt", "wb") as out_file:
        pickle.dump(np.array(shifted), out_file)


def get_obs_dist(obs_vars, cond_vars):
    """Estimate the joint distribution of ``obs_vars`` per ``cond_vars`` value.

    Loads ``sachs.interventional.txt``, keeps the rows whose ``INT`` column
    is 0, and groups them by the raw values found in the ``cond_vars``
    columns. For each group, the ``obs_vars`` columns (minus 1) are passed to
    ``get_joint_distributions_from_samples`` with a cardinality of 3 per
    variable. One summary line is printed per group.

    NOTE(review): ``INT == 0`` presumably marks observational samples and the
    ``- 1`` presumably maps levels 1..3 onto 0..2 -- confirm against the
    dataset description.

    Args:
        obs_vars: List of column labels whose joint distribution is estimated.
        cond_vars: List of column labels to condition on.

    Returns:
        dict keyed by the tuple of raw (unshifted) conditioning values, in
        order of first appearance in the file; each value is the distribution
        dict returned by ``get_joint_distributions_from_samples``.

    Raises:
        ValueError: If any label in ``obs_vars`` or ``cond_vars`` is unknown.
    """
    sachs_labels = ["Raf", "Mek", "Plcg", "PIP2", "PIP3", "Erk", "Akt", "PKA", "PKC", "P38", "Jnk", "INT"]
    file_root = "/path_to_project/CausalSachs/GroundTruth/Dataset/"
    file_name = file_root + "sachs.interventional.txt"

    # Resolve every label once, before the file is loaded, so that a
    # misspelled variable name fails immediately instead of per group.
    cond_columns = [sachs_labels.index(var) for var in cond_vars]
    obs_columns = [sachs_labels.index(lb) for lb in obs_vars]
    int_column = sachs_labels.index("INT")
    dim_list = [3] * len(obs_vars)

    dataset = np.genfromtxt(file_name, delimiter=" ").astype(int)

    # Group the INT == 0 rows by their conditioning values. A key only exists
    # once a row has been stored under it, so no group is ever empty.
    rows_by_condition = {}
    for samp in dataset:
        if samp[int_column] != 0:
            continue
        key = tuple(samp[cond_columns])
        rows_by_condition.setdefault(key, []).append(samp)

    dist_dict = {}
    for key, rows in rows_by_condition.items():
        cond_data = np.vstack(rows)[:, obs_columns] - 1
        ret = get_joint_distributions_from_samples(dim_list, cond_data)
        print(key, "---> ", cond_data.shape, "dist:", ret)
        dist_dict[key] = ret

    return dist_dict




def get_do_samples(obs_vars, intv_var):
    """Dump the samples recorded under ``intv_var``'s code, one pickle per level.

    Loads ``sachs.interventional.txt`` and keeps the rows whose ``INT``
    column equals the 1-based column position of ``intv_var``. Those rows are
    bucketed by the value (1, 2 or 3) found in the ``intv_var`` column. For
    every non-empty bucket, the ``obs_vars`` columns (minus 1) are pickled to
    ``sachs_P(<vars>|do(<intv_var>=<level - 1>)).txt`` in the dataset
    directory. Despite the ``.txt`` suffix, the files are binary pickles.

    NOTE(review): the ``INT`` column presumably stores the 1-based index of
    the intervened variable -- confirm against the dataset description.

    Args:
        obs_vars: Sequence of column labels to keep.
        intv_var: Label of the variable whose ``INT`` code selects the rows.

    Returns:
        None.

    Raises:
        ValueError: If a label is unknown.
        KeyError: If a selected row holds a level other than 1, 2 or 3 in the
            ``intv_var`` column.
    """
    sachs_labels = ["Raf", "Mek", "Plcg", "PIP2", "PIP3", "Erk", "Akt", "PKA", "PKC", "P38", "Jnk", "INT"]
    file_root = "/path_to_project/CausalSachs/GroundTruth/Dataset/"
    source_path = file_root + "sachs.interventional.txt"
    table = np.genfromtxt(source_path, delimiter=" ").astype(int)

    target_col = sachs_labels.index(intv_var)
    intv_code = target_col + 1
    int_col = sachs_labels.index("INT")

    rows_by_level = {1: [], 2: [], 3: []}
    for row in table:
        if row[int_col] != intv_code:
            continue
        rows_by_level[row[target_col]].append(row.reshape(1, -1))

    for level, rows in rows_by_level.items():
        if not rows:
            continue
        stacked = np.concatenate(rows, axis=0)
        columns = [sachs_labels.index(name) for name in obs_vars]
        shifted = stacked[:, columns] - 1

        joined_vars = "_".join(name for name in obs_vars)
        out_path = file_root + f"sachs_P({joined_vars}|do({intv_var}={level-1})).txt"
        with open(out_path, "wb") as out_file:
            pickle.dump(np.array(shifted), out_file)




def get_do_dist(obs_vars, intv_var):
    """Estimate the joint distribution of ``obs_vars`` per level of ``intv_var``.

    Loads ``sachs.interventional.txt`` and keeps the rows whose ``INT``
    column equals the 1-based column position of ``intv_var``. Those rows are
    bucketed by the value (1, 2 or 3) found in the ``intv_var`` column. For
    every non-empty bucket, the ``obs_vars`` columns (minus 1) are passed to
    ``get_joint_distributions_from_samples`` with a cardinality of 3 per
    variable. One summary line is printed per bucket.

    NOTE(review): the ``INT`` column presumably stores the 1-based index of
    the intervened variable -- confirm against the dataset description.

    Args:
        obs_vars: List of column labels whose joint distribution is estimated.
        intv_var: Label of the variable whose ``INT`` code selects the rows.

    Returns:
        dict keyed by the raw level (1, 2 or 3) of ``intv_var``; levels with
        no rows are omitted. Each value is the distribution dict returned by
        ``get_joint_distributions_from_samples``.

    Raises:
        ValueError: If a label is unknown.
        KeyError: If a selected row holds a level other than 1, 2 or 3 in the
            ``intv_var`` column.
    """
    sachs_labels = ["Raf", "Mek", "Plcg", "PIP2", "PIP3", "Erk", "Akt", "PKA", "PKC", "P38", "Jnk", "INT"]
    file_root = "/path_to_project/CausalSachs/GroundTruth/Dataset/"
    source_path = file_root + "sachs.interventional.txt"

    target_col = sachs_labels.index(intv_var)
    intv_code = target_col + 1
    table = np.genfromtxt(source_path, delimiter=" ").astype(int)

    int_col = sachs_labels.index("INT")
    rows_by_level = {1: [], 2: [], 3: []}
    for row in table:
        if row[int_col] != intv_code:
            continue
        rows_by_level[row[target_col]].append(row.reshape(1, -1))

    dist_dict = {}
    for level, rows in rows_by_level.items():
        if not rows:
            continue
        columns = [sachs_labels.index(name) for name in obs_vars]
        level_data = np.concatenate(rows, axis=0)[:, columns] - 1

        cardinalities = [3] * len(obs_vars)
        dist = get_joint_distributions_from_samples(cardinalities, level_data)
        print(level, "---> ", level_data.shape, "dist:", dist)
        dist_dict[level] = dist

    return dist_dict



# Script entry point. Most of the body is commented-out experiment code; the
# only live statements pick a BIF network name, load it and print the result.
if __name__ == '__main__':

    # Candidate observed-variable sets from earlier experiments (all disabled).
    # Obs = ["PKA", "PKC", "Raf", "Mek", "Erk"]
    # Obs = ["PKC", "PKA", "Raf", "Mek", "Erk", "Akt"]
    # Obs = ["PKC", "PKA", "Raf", "Mek", "Erk", "Akt"]
    # Obs = ["PIP2","PKA", "PKC", "Raf", "Mek", "Erk","Akt"]
    # Obs = ["PKA","Mek"]

    # Disabled comparison of a conditional against an interventional estimate.
    # ret2 = get_obs_dist(["PKA"], ["PKC"])
    # ret1 = get_do_dist(["PKC"], "PIP2")
    # print("Obs--->",ret2)
    # print("Intv--->",ret1)
    # Author's notes on the results:
    #intv data: PKC, PKA , Mek(do(0)->600)
    # P(Erk|PKA=2) P(Erk|do(PKA=2)) high match
    # P(Erk|PKC=2) P(Erk|do(PKC=2)) high match  full match.


    # Disabled calls that export / inspect interventional data for PKA.
    # get_do_samples(Obs, "PKA")
    # ret= get_do_dist(["Erk"], "PKA")
    # ret= get_do_dist(["Akt"], "PKA")
    # ret= get_do_dist(["Mek"], "PKA")
    # print(ret)


    # Name of the BIF network to load; alternatives are kept commented out.
    # bif_file = 'sprinkler'
    # bif_file = 'alarm'
    # bif_file = 'andes'
    # bif_file = 'asia'
    # bif_file = 'pathfinder'
    bif_file = 'sachs'
    # # bif_file = 'miserables'
    # # bif_file = 'filepath/to/model.bif'
    #
    # NOTE(review): `gum` is not bound by any import visible in this file, so
    # this line raises NameError unless the name is provided elsewhere.
    # Presumably `import pyAgrum as gum` was intended -- confirm and add it.
    bn=gum.loadBN('data/'+bif_file+'.bif')

    # # Loading DAG with model parameters from bif file.
    # NOTE(review): `bn` was just rebound to the result of `gum.loadBN`, yet
    # `import_DAG` is called on it with the bare network name. This looks like
    # a call meant for a module (possibly bnlearn) rather than for the loaded
    # network object -- confirm which library each of these two lines targets.
    model = bn.import_DAG(bif_file)
    print(model)
